# -*- coding: utf-8 -*- 
"""
========================================================================================================================
@project : my-sanic
@file: BaseModel
@Author: mengying
@email: 652044581@qq.com
@date: 2023/3/16 23:05
@desc: mongodb数据库操作的基类
========================================================================================================================
"""

from typing import Dict, List, Union
from database import commons
from addict import Dict as asDict
from utils.myTimeFormat import MyTime
from bson import ObjectId


class Transform:
    """Helpers for turning ObjectId-style ``_id`` values into strings."""

    @staticmethod
    def convert_string(data: Union[Dict, List]) -> Union[Dict, List]:
        """Stringify the ``_id`` of a document, or of every document in a list.

        The input is modified in place and then returned. ``None`` becomes
        ``{}``; any other empty value is returned unchanged. Documents whose
        ``_id`` is missing or falsy are left untouched.
        """
        if data is None:
            return {}
        if not data:
            return data

        # Treat a single document as a one-element batch so both shapes share one loop.
        documents = data if isinstance(data, list) else [data]
        for document in documents:
            current_id = document.get("_id", None)
            if current_id:
                document["_id"] = str(current_id)
        return data


class BaseMongoMethod:
    """Configuration shared by the Mongo data-access helpers.

    Subclasses point ``_model_class`` at a model class with a no-argument
    constructor, a ``dump`` method, and a nested ``Meta`` that carries
    ``collection_name``.
    """

    # Model class used for serialization and collection lookup; subclasses must set it.
    _model_class = None
    # Flags read by subclasses (e.g. AsyncMongoMethod) to decide whether to stamp
    # ``create_time`` on insert and ``update_time`` on update.
    _create_time_auto = True
    _update_time_auto = True

    @property
    def _db(self):
        """Database handle shared through ``database.commons``."""
        return commons.db

    def _dump(self, data) -> dict:
        """Serialize ``data`` with a fresh instance of ``_model_class``."""
        schema = self._model_class()
        return schema.dump(data)

    @property
    def collection_name(self) -> str:
        """Collection name declared on ``_model_class.Meta``.

        Raises ``Exception`` when ``_model_class`` is unset or has no ``Meta``.
        """
        model = self._model_class
        if not model:
            raise Exception("请定义您的model_class")

        meta = getattr(model, "Meta", None)
        if not meta:
            raise Exception("请定义您的model_class中的meta数据")
        return meta.collection_name


class AsyncMongoMethod(BaseMongoMethod):
    """Asynchronous CRUD helpers bound to the collection named by ``_model_class.Meta``.

    Every driver call goes through ``self._db[self.collection_name]``.
    NOTE(review): the ``await`` / ``async for`` usage implies an asyncio driver
    (Motor or PyMongo's async API) behind ``database.commons.db`` -- confirm.
    """

    @staticmethod
    def _projection(field: List = None, exclude: List = None) -> Union[Dict, None]:
        """Build a projection document from include/exclude field lists.

        @params: field    names to include (mapped to 1)
        @params: exclude  names to omit (mapped to 0); applied after ``field``,
                          so a name present in both lists ends up excluded
        @return: the projection dict, or ``None`` when both lists are empty/None
        """
        projection = {}
        if field:
            projection.update({key: 1 for key in field})
        if exclude:
            projection.update({key: 0 for key in exclude})
        return projection or None

    async def insert_one(self, data: Dict, validation: Dict = None) -> Union[str, None]:
        """Insert a single document built from ``data``.

        @params: data        raw fields, serialized through ``self._dump``
        @params: validation  optional query; if a document already matches it,
                             nothing is inserted
        @return: the new ``_id`` as a string, or ``None`` when ``validation`` matched
        """
        model_template = self._dump(data)

        if validation and await self.find_one(validation):
            return None

        if self._create_time_auto:
            model_template["create_time"] = MyTime.TimeFormat()

        result = await self._db[self.collection_name].insert_one(model_template)
        return str(result.inserted_id)

    async def find_one(self, condition: Dict, field: List = None, exclude: List = None) -> Dict:
        """Fetch the first document matching ``condition``.

        @params: condition  query filter
        @params: field      fields to include in the result
        @params: exclude    fields to omit from the result (MongoDB rejects mixing
                            inclusion and exclusion except for ``_id``)
        @return: the document with ``_id`` stringified, or ``{}`` when nothing matches
        """
        collection = self._db[self.collection_name]
        projection = self._projection(field, exclude)
        if projection:
            data = await collection.find_one(condition, projection)
        else:
            data = await collection.find_one(condition)
        return Transform.convert_string(data)

    async def insert_many(self, collection_list: List[Dict]) -> List:
        """Insert several documents in one call.

        @params: collection_list  list of raw field dicts, each serialized through ``self._dump``
        @return: the inserted ``_id`` values as strings (``[]`` for empty input)
        """
        # The driver rejects an empty document list, so short-circuit instead of raising.
        if not collection_list:
            return []

        container = []
        for item in collection_list:
            model_template = self._dump(item)
            if self._create_time_auto:
                model_template["create_time"] = MyTime.TimeFormat()
            container.append(model_template)

        result = await self._db[self.collection_name].insert_many(container)
        return [str(_id) for _id in result.inserted_ids]

    async def find(self, condition: Dict, field: List = None, exclude: List = None, page: int = 1,
                   size: int = 0, sort_list: List = None, total_count: bool = False,
                   ) -> Union[Dict, List]:
        """Fetch every document matching ``condition``, optionally sorted and paginated.

        @params: condition    query filter
        @params: field        fields to include (default: all fields)
        @params: exclude      fields to omit
        @params: page         1-based page number; values below 1 are treated as 1
        @params: size         page size; 0 (default) disables pagination
        @params: sort_list    sort keys, ``"field"`` ascending / ``"-field"`` descending
        @params: total_count  when True, also return the pre-pagination match count
        @return: list of documents (``_id`` stringified), or
                 ``{"data": [...], "total": n}`` when ``total_count`` is True
        """
        collection = self._db[self.collection_name]
        projection = self._projection(field, exclude)
        cursor = collection.find(condition, projection) if projection else collection.find(condition)

        if sort_list:
            sort_key = [(key[1:], -1) if key.startswith("-") else (key, 1) for key in sort_list]
            cursor = cursor.sort(sort_key)

        if size > 0:
            # Clamp so page <= 0 yields the first page instead of a negative skip,
            # which the driver rejects.
            start = (max(page, 1) - 1) * size
            cursor = cursor.skip(start).limit(size)

        data_list = Transform.convert_string([document async for document in cursor])

        if total_count:
            my_dict = asDict()
            my_dict.data = data_list
            my_dict.total = await self.count(condition)
            return my_dict.to_dict()

        return data_list

    async def count(self, condition: Dict) -> int:
        """Return the number of documents matching ``condition``."""
        return await self._db[self.collection_name].count_documents(condition)

    async def update(self, condition: Dict, update: Dict, upsert: bool = False) -> int:
        """Update every document matching ``condition``.

        @params: condition  query filter
        @params: update     either a plain field dict (applied with ``$set``), e.g.
                            ``{"name": "xiaoming"}``, or an operator document, e.g.
                            ``{"$push": {"tags": "x"}, "$inc": {"n": 1}}``
        @params: upsert     insert a new document when nothing matches
        @return: number of matched documents
        @raises: ValueError when ``update`` is empty
        """
        if not update:
            raise ValueError("update must not be empty")

        first_key = next(iter(update))
        if str(first_key).startswith("$"):
            # Operator document: keep every operator and copy one level deep so the
            # caller's dicts are never mutated.
            document = {op: dict(value) if isinstance(value, dict) else value
                        for op, value in update.items()}
        else:
            document = {"$set": dict(update)}

        if self._update_time_auto:
            # Always stamp via $set; merging it into $push/$inc/... would push into
            # or increment an ``update_time`` field instead of setting it.
            document.setdefault("$set", {})["update_time"] = MyTime.TimeFormat()

        result = await self._db[self.collection_name].update_many(condition, document, upsert=upsert)
        return result.matched_count

    async def delete(self, condition: Dict) -> int:
        """Delete every document matching ``condition`` and return the deleted count."""
        result = await self._db[self.collection_name].delete_many(condition)
        return result.deleted_count

    async def aggregate(self, pipeline: List[Dict]) -> List:
        """Run an aggregation pipeline (e.g. ``$lookup`` joins) and return all results.

        @params: pipeline  list of aggregation stages
        @return: list of result documents, unmodified
        """
        import inspect

        cursor = self._db[self.collection_name].aggregate(pipeline)
        # Motor returns the cursor directly, while PyMongo's async API returns a
        # coroutine that resolves to it; both cursors are only async-iterable.
        if inspect.isawaitable(cursor):
            cursor = await cursor
        return [document async for document in cursor]

    async def backups(self, condition: Dict, node: str = "default") -> Union[str, None]:
        """Copy the first document matching ``condition`` into a backup collection.

        The backup collection is ``<collection>_<node>_backups``. Any earlier
        backup with the same ``_id`` is replaced.

        @params: condition  query filter selecting the document to back up
        @params: node       name segment distinguishing backup sets
        @return: the driver's insert result, or ``None`` when nothing matches
                 (NOTE(review): the annotation says ``str`` but an insert result
                 object is returned; kept for caller compatibility)
        """
        collection_name_backups = "_".join([self.collection_name, node, "backups"])
        # Read the raw document so ``_id`` keeps its original BSON type; going through
        # find_one() would stringify it and break non-ObjectId ids.
        backups_data = await self._db[self.collection_name].find_one(condition)
        if not backups_data:
            return None

        backup_collection = self._db[collection_name_backups]
        await backup_collection.delete_many({"_id": backups_data["_id"]})
        return await backup_collection.insert_one(backups_data)

    async def restore(self, condition: Dict, node: str = "default") -> Union[str, None]:
        """Restore the first backed-up document matching ``condition``.

        Reads from ``<collection>_<node>_backups``. It replaces any live document
        with the same ``_id``.

        @params: condition  query filter evaluated against the backup collection
        @params: node       name segment distinguishing backup sets
        @return: the driver's insert result, or ``None`` when no backup matches
                 (NOTE(review): the annotation says ``str`` but an insert result
                 object is returned; kept for caller compatibility)
        """
        collection_name_backups = "_".join([self.collection_name, node, "backups"])
        restore_data = await self._db[collection_name_backups].find_one(condition)
        if not restore_data:
            return None

        live_collection = self._db[self.collection_name]
        # Use the stored ``_id`` as-is; coercing it to ObjectId fails for custom ids.
        await live_collection.delete_many({"_id": restore_data["_id"]})
        return await live_collection.insert_one(restore_data)


